fix(finetune): calculate fitting stat when using random fitting in finetuning process #4928
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: devel
Conversation
for more information, see https://pre-commit.ci
📝 Walkthrough

Adds a public hook compute_fitting_input_stat to BaseAtomicModel and DPAtomicModel, wires it into model creation when bias_adjust_mode == "set-by-statistic", implements compute_input_stats in GeneralFitting, and updates tests to set model.data_stat_nbatch and extend state-dict comparisons to include fparam/aparam.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant MakeModel as make_model
    participant AtomicModel as atomic_model
    participant FittingNet as fitting_net
    MakeModel->>AtomicModel: change_out_bias(merged, bias_adjust_mode)
    alt bias_adjust_mode == "set-by-statistic"
        MakeModel->>AtomicModel: compute_fitting_input_stat(merged)
        AtomicModel->>FittingNet: compute_input_stats(sample_merged, protection)
        FittingNet-->>AtomicModel: updated stats (fparam_avg, fparam_inv_std, aparam_avg, aparam_inv_std)
    else
        Note right of MakeModel: no fitting input-stat computation
    end
```
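The hook pattern in the diagram can be sketched as follows. This is a simplified illustration with placeholder class and function bodies that mirror the diagram's names, not the actual deepmd-kit implementation: the base class exposes a no-op hook, the derived atomic model delegates to its fitting net, and model creation calls the hook only in "set-by-statistic" mode.

```python
from typing import Callable, Union


class BaseAtomicModel:
    def compute_fitting_input_stat(
        self, merged: Union[Callable[[], list[dict]], list[dict]]
    ) -> None:
        # No-op in the base class; subclasses with a fitting net override it.
        pass


class FittingNet:
    def __init__(self) -> None:
        self.fparam_avg = None

    def compute_input_stats(self, sample_merged, protection: float = 1e-2) -> None:
        # Resolve the lazy sampler if one was passed instead of raw data.
        data = sample_merged() if callable(sample_merged) else sample_merged
        # ... real code would accumulate fparam/aparam statistics from `data` ...
        self.fparam_avg = 0.0  # placeholder value standing in for real stats


class DPAtomicModel(BaseAtomicModel):
    def __init__(self) -> None:
        self.fitting_net = FittingNet()

    def compute_fitting_input_stat(self, merged) -> None:
        # Delegate to the fitting net, as in the sequence diagram.
        self.fitting_net.compute_input_stats(merged, protection=1e-2)


def make_model(atomic_model: BaseAtomicModel, merged, bias_adjust_mode: str):
    # change_out_bias(...) would run first; only "set-by-statistic" also
    # recomputes the fitting input statistics.
    if bias_adjust_mode == "set-by-statistic":
        atomic_model.compute_fitting_input_stat(merged)
    return atomic_model
```

With any other bias_adjust_mode the hook is skipped, matching the `else` branch of the diagram.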
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Possibly related PRs
Suggested reviewers
Codecov Report

❌ Patch coverage is
Additional details and impacted files

```
@@ Coverage Diff @@
##            devel    #4928      +/-   ##
==========================================
  Coverage   84.23%   84.23%
==========================================
  Files         709      709
  Lines       70078    70092       +14
  Branches     3619     3619
==========================================
+ Hits        59032    59044       +12
- Misses       9880     9883        +3
+ Partials     1166     1165        -1
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- deepmd/pd/model/atomic_model/base_atomic_model.py (1 hunks)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (1 hunks)
- deepmd/pd/model/model/make_model.py (1 hunks)
- deepmd/pt/model/atomic_model/base_atomic_model.py (1 hunks)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (2 hunks)
- deepmd/pt/model/model/make_model.py (1 hunks)
- source/tests/pd/test_training.py (2 hunks)
- source/tests/pt/test_training.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Always run `ruff check .` and `ruff format .` before committing changes to Python code
Files:
- deepmd/pt/model/atomic_model/dp_atomic_model.py
- deepmd/pt/model/atomic_model/base_atomic_model.py
- source/tests/pd/test_training.py
- deepmd/pd/model/atomic_model/dp_atomic_model.py
- deepmd/pd/model/atomic_model/base_atomic_model.py
- deepmd/pd/model/model/make_model.py
- deepmd/pt/model/model/make_model.py
- source/tests/pt/test_training.py
🧠 Learnings (2)
📚 Learning: 2025-09-18T11:37:10.532Z
Learnt from: CR
Repo: deepmodeling/deepmd-kit PR: 0
File: AGENTS.md:0-0
Timestamp: 2025-09-18T11:37:10.532Z
Learning: Applies to source/tests/tf/test_dp_test.py : Keep the core TensorFlow test `source/tests/tf/test_dp_test.py` passing; use it for quick validation
Applied to files:
- source/tests/pd/test_training.py
- source/tests/pt/test_training.py
📚 Learning: 2024-09-19T04:25:12.408Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-09-19T04:25:12.408Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
Applied to files:
- source/tests/pd/test_training.py
- source/tests/pt/test_training.py
🧬 Code graph analysis (6)
deepmd/pt/model/atomic_model/dp_atomic_model.py (3)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (406-424)
- deepmd/pt/model/atomic_model/base_atomic_model.py (1): compute_fitting_stat (496-512)
- deepmd/pt/model/task/fitting.py (1): compute_input_stats (78-157)

deepmd/pt/model/atomic_model/base_atomic_model.py (1)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (336-354)

deepmd/pd/model/atomic_model/dp_atomic_model.py (3)
- deepmd/pd/model/atomic_model/base_atomic_model.py (1): compute_fitting_stat (518-534)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (336-354)
- deepmd/pd/model/task/fitting.py (1): compute_input_stats (75-160)

deepmd/pd/model/atomic_model/base_atomic_model.py (1)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (406-424)

deepmd/pd/model/model/make_model.py (3)
- deepmd/pd/model/atomic_model/base_atomic_model.py (1): compute_fitting_stat (518-534)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (406-424)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (336-354)

deepmd/pt/model/model/make_model.py (2)
- deepmd/pt/model/atomic_model/base_atomic_model.py (1): compute_fitting_stat (496-512)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (336-354)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test C++ (true)
- GitHub Check: Test C++ (false)
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
🔇 Additional comments (10)
deepmd/pd/model/model/make_model.py (1)

231-232: LGTM! Correct invocation of compute_fitting_stat. The conditional call to `compute_fitting_stat` after `change_out_bias` appropriately addresses the PR objective of computing fitting statistics when using random fitting in finetuning (set-by-statistic mode). The implementation correctly reuses the merged data.

source/tests/pd/test_training.py (2)

92-100: LGTM! Appropriate test assertion for fparam/aparam keys. The broadened exclusion condition correctly validates that fparam and aparam statistics are preserved during random finetuning, aligning with the PR's objective to compute fitting statistics properly.

197-197: LGTM! Configuration for data statistics batching. Adding `data_stat_nbatch = 100` appropriately exercises the data statistics batching behavior that's central to this PR's fitting statistics computation.

deepmd/pt/model/atomic_model/base_atomic_model.py (1)

496-512: LGTM! Appropriate placeholder for PT base atomic model. The no-op implementation is correct for the base class, allowing derived classes to provide concrete implementations. The documentation correctly references `torch.Tensor` for the PyTorch path.

deepmd/pt/model/model/make_model.py (1)

235-236: LGTM! Correct invocation of compute_fitting_stat in PT path. The implementation mirrors the PD path and correctly invokes `compute_fitting_stat` when bias_adjust_mode is "set-by-statistic", addressing the PR's objective for the PyTorch path.

source/tests/pt/test_training.py (2)

95-103: LGTM! Test assertions align with PD path. The broadened exclusion condition for fparam/aparam keys correctly validates the new fitting statistics computation during random finetuning in the PyTorch path.

263-263: LGTM! Configuration mirrors PD test setup. Setting `data_stat_nbatch = 100` appropriately exercises data statistics batching in the PyTorch path, consistent with the PD tests.

deepmd/pt/model/atomic_model/dp_atomic_model.py (3)

8-8: LGTM! Union import for type hints. The `Union` import is correctly added to support the type hints for the new `compute_fitting_stat` method signature.

332-332: LGTM! Refactored to use compute_fitting_stat. Good refactoring that centralizes fitting statistics computation through the new `compute_fitting_stat` method, improving code organization and maintainability.

336-354: LGTM! Proper implementation of compute_fitting_stat. The method correctly delegates to `fitting_net.compute_input_stats` with the `data_stat_protect` parameter, providing a clean interface for computing fitting statistics from packed data.
Actionable comments posted: 2
📒 Files selected for processing (2)
- deepmd/pd/model/atomic_model/base_atomic_model.py (1 hunks)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (2 hunks)
🧬 Code graph analysis (2)
deepmd/pd/model/atomic_model/base_atomic_model.py (2)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (336-354)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (406-424)

deepmd/pt/model/atomic_model/dp_atomic_model.py (2)
- deepmd/pt/model/atomic_model/base_atomic_model.py (1): compute_fitting_stat (496-512)
- deepmd/pt/model/task/fitting.py (1): compute_input_stats (78-157)
🔇 Additional comments (2)
deepmd/pt/model/atomic_model/dp_atomic_model.py (2)

8-8: LGTM: Import addition supports new method signature. The `Union` import is necessary for the `Union[Callable[[], list[dict]], list[dict]]` type hint in the new `compute_fitting_stat` method.

332-332: Good refactoring to use the public API. Calling `self.compute_fitting_stat(wrapped_sampler)` instead of directly accessing `self.fitting_net.compute_input_stats` establishes a consistent public interface and enables proper encapsulation.
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/pt/model/atomic_model/base_atomic_model.py (1)
496-512: LGTM! Clean API extension for computing fitting statistics. The method provides a well-documented hook for derived classes to compute fitting statistics during finetuning. The no-op default (`pass`) is appropriate since not all atomic model types require fitting statistics computation.

Optional: Consider adding a clarifying comment. To make the intent clearer, you could add a brief comment before the `pass` statement:

```diff
         the lazy function helps by only sampling once.
         """
+        # No-op in base class; derived classes override if fitting statistics are needed.
         pass
```

Reminder: As per coding guidelines, ensure you run the following before committing:

```shell
ruff check .
ruff format .
```
📒 Files selected for processing (4)
- deepmd/pd/model/atomic_model/base_atomic_model.py (1 hunks)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (1 hunks)
- deepmd/pt/model/atomic_model/base_atomic_model.py (1 hunks)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- deepmd/pd/model/atomic_model/base_atomic_model.py
- deepmd/pt/model/atomic_model/dp_atomic_model.py
- deepmd/pd/model/atomic_model/dp_atomic_model.py
🧬 Code graph analysis (1)
deepmd/pt/model/atomic_model/base_atomic_model.py (2)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (406-424)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (336-354)
Limitation: the input stat is not implemented in the Python backend.
Actionable comments posted: 0
🧹 Nitpick comments (1)
source/tests/common/dpmodel/test_fitting_stat.py (1)
14-66: Consider renaming helper functions to remove the `_pt` suffix. The helper functions `_make_fake_data_pt`, `_brute_fparam_pt`, and `_brute_aparam_pt` use a `_pt` suffix that typically denotes PyTorch-related code. However, these functions are in the dpmodel test directory and use NumPy arrays, not PyTorch tensors. To avoid confusion, consider renaming them to remove the `_pt` suffix:

- `_make_fake_data_pt` → `_make_fake_data`
- `_brute_fparam_pt` → `_brute_fparam`
- `_brute_aparam_pt` → `_brute_aparam`

Apply this diff to rename the functions:

```diff
-def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds):
+def _make_fake_data(sys_natoms, sys_nframes, avgs, stds):
     merged_output_stat = []
     nsys = len(sys_natoms)
     ndof = len(avgs)
     for ii in range(nsys):
         sys_dict = {}
         tmp_data_f = []
         tmp_data_a = []
         for jj in range(ndof):
             rng = np.random.default_rng(2025 * ii + 220 * jj)
             tmp_data_f.append(
                 rng.normal(loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], 1))
             )
             rng = np.random.default_rng(220 * ii + 1636 * jj)
             tmp_data_a.append(
                 rng.normal(
                     loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], sys_natoms[ii])
                 )
             )
         tmp_data_f = np.transpose(tmp_data_f, (1, 2, 0))
         tmp_data_a = np.transpose(tmp_data_a, (1, 2, 0))
         sys_dict["fparam"] = tmp_data_f
         sys_dict["aparam"] = tmp_data_a
         merged_output_stat.append(sys_dict)
     return merged_output_stat


-def _brute_fparam_pt(data, ndim):
+def _brute_fparam(data, ndim):
     adata = [ii["fparam"] for ii in data]
     all_data = []
     for ii in adata:
         tmp = np.reshape(ii, [-1, ndim])
         if len(all_data) == 0:
             all_data = np.array(tmp)
         else:
             all_data = np.concatenate((all_data, tmp), axis=0)
     avg = np.average(all_data, axis=0)
     std = np.std(all_data, axis=0)
     return avg, std


-def _brute_aparam_pt(data, ndim):
+def _brute_aparam(data, ndim):
     adata = [ii["aparam"] for ii in data]
     all_data = []
     for ii in adata:
         tmp = np.reshape(ii, [-1, ndim])
         if len(all_data) == 0:
             all_data = np.array(tmp)
         else:
             all_data = np.concatenate((all_data, tmp), axis=0)
     avg = np.average(all_data, axis=0)
     std = np.std(all_data, axis=0)
     return avg, std
```

Then update the test to use the renamed functions:

```diff
     def test(self) -> None:
         descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16)
         fitting = EnergyFittingNet(
             descrpt.get_ntypes(),
             descrpt.get_dim_out(),
             neuron=[240, 240, 240],
             resnet_dt=True,
             numb_fparam=3,
             numb_aparam=3,
         )
         avgs = [0, 10, 100]
         stds = [2, 0.4, 0.00001]
         sys_natoms = [10, 100]
         sys_nframes = [5, 2]
-        all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds)
-        frefa, frefs = _brute_fparam_pt(all_data, len(avgs))
-        arefa, arefs = _brute_aparam_pt(all_data, len(avgs))
+        all_data = _make_fake_data(sys_natoms, sys_nframes, avgs, stds)
+        frefa, frefs = _brute_fparam(all_data, len(avgs))
+        arefa, arefs = _brute_aparam(all_data, len(avgs))
         fitting.compute_input_stats(all_data, protection=1e-2)
         frefs_inv = 1.0 / frefs
         arefs_inv = 1.0 / arefs
         frefs_inv[frefs_inv > 100] = 100
         arefs_inv[arefs_inv > 100] = 100
         np.testing.assert_almost_equal(frefa, fitting.fparam_avg)
         np.testing.assert_almost_equal(frefs_inv, fitting.fparam_inv_std)
         np.testing.assert_almost_equal(arefa, fitting.aparam_avg)
         np.testing.assert_almost_equal(arefs_inv, fitting.aparam_inv_std)
```
📒 Files selected for processing (3)
- deepmd/dpmodel/fitting/general_fitting.py (2 hunks)
- deepmd/pd/model/atomic_model/base_atomic_model.py (1 hunks)
- source/tests/common/dpmodel/test_fitting_stat.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
source/tests/common/dpmodel/test_fitting_stat.py (3)
- deepmd/pd/model/atomic_model/base_atomic_model.py (1): get_ntypes (628-629)
- source/tests/tf/common.py (1): numb_aparam (909-910)
- deepmd/dpmodel/fitting/general_fitting.py (1): compute_input_stats (225-288)

deepmd/pd/model/atomic_model/base_atomic_model.py (1)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (1): compute_fitting_input_stat (404-422)

deepmd/dpmodel/fitting/general_fitting.py (1)
- deepmd/pt/model/task/fitting.py (1): compute_input_stats (78-157)
🔇 Additional comments (4)

deepmd/pd/model/atomic_model/base_atomic_model.py (1)

518-534: LGTM! Placeholder method correctly defines the public API. The method signature and docstring are well-structured. The placeholder implementation follows the correct pattern for a base class that will be overridden by derived classes (e.g., DPAtomicModel). The docstring correctly references `paddle.Tensor` for the PaddlePaddle framework and clearly describes the purpose of computing input statistics from packed data.

deepmd/dpmodel/fitting/general_fitting.py (2)

7-7: LGTM! Import addition supports the new method signature. The `Callable` import is necessary for the type hint in the `compute_input_stats` method.

225-289: Excellent implementation with numerically stable variance calculation. The method correctly handles both frame and atomic parameters:

- Early return optimization: skips computation when no parameters are present (lines 245-247).
- Lazy evaluation support: properly handles both immediate and deferred data sources (lines 248-251).
- Protection threshold: prevents division by zero by clamping small standard deviations to the protection value before computing inverse std (lines 258-262, 281-285).
- Numerically stable aparam computation: uses parallel aggregation with sum and sum-of-squares for variance calculation (lines 267-280), which is more stable than the direct approach used for fparam.

The implementation is consistent with the PyTorch counterpart in deepmd/pt/model/task/fitting.py.
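The aggregation strategy described in the comment above can be sketched as follows. This is an illustrative NumPy version under assumed names, not the deepmd-kit code: accumulate per-system sums and sums of squares, then derive the mean and a protection-clamped inverse standard deviation.

```python
import numpy as np


def input_stats(systems, key, ndim, protection=1e-2):
    """Sketch: mean and clamped inverse std over a list of per-system dicts."""
    total = np.zeros(ndim)
    total_sq = np.zeros(ndim)
    count = 0
    for sys_dict in systems:
        # Flatten frames/atoms; the parameter dimension is last.
        arr = np.reshape(sys_dict[key], (-1, ndim))
        total += arr.sum(axis=0)
        total_sq += (arr**2).sum(axis=0)
        count += arr.shape[0]
    avg = total / count
    # Population variance via sum and sum-of-squares (parallel aggregation);
    # clamp at zero to guard against tiny negative values from rounding.
    var = np.maximum(total_sq / count - avg**2, 0.0)
    # Clamp small stddevs to `protection` so inv_std never exceeds 1/protection.
    std = np.maximum(np.sqrt(var), protection)
    return avg, 1.0 / std
```

Because only running sums are kept, systems can be processed one at a time without concatenating all data in memory, which is the stability and memory advantage the review points out.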
source/tests/common/dpmodel/test_fitting_stat.py (1)

69-99: Well-designed test that validates protection threshold behavior. The test effectively validates the `compute_input_stats` implementation:

- Comprehensive test data: uses multiple systems with varying numbers of atoms and frames.
- Protection threshold validation: the synthetic data includes `stds=[2, 0.4, 0.00001]`, where the third value (0.00001 << 0.01) specifically tests the protection threshold mechanism.
- Correct reference calculation: lines 88-91 properly compute reference inverse standard deviations and clip them at 100, which corresponds to the protection threshold (1.0 / 0.01 = 100).
- Proper assertions: validates both mean and inverse std for both fparam and aparam.

The test ensures that when standard deviation is below the protection threshold, the inverse std is correctly capped to prevent numerical instability.
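The equivalence the test relies on can be checked directly: clamping the standard deviation at `protection` before inverting gives the same result as capping the inverse std at `1 / protection`. The values below assume `protection = 1e-2` and the stds from the test data.

```python
import numpy as np

protection = 1e-2
std = np.array([2.0, 0.4, 0.00001])  # the stds used in the synthetic test data

# Implementation side: clamp std, then invert.
inv_clamped = 1.0 / np.maximum(std, protection)
# Test side: invert, then cap at 1 / protection (= 100 here).
inv_capped = np.minimum(1.0 / std, 1.0 / protection)

assert np.allclose(inv_clamped, inv_capped)
```

Only the third component hits the cap; the first two are far enough from the threshold to pass through unchanged.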
I add
Actionable comments posted: 0
🧹 Nitpick comments (2)
source/tests/common/dpmodel/test_fitting_stat.py (2)
41-66: Reduce duplication between `_brute_fparam_pt` and `_brute_aparam_pt`. The two brute-force helpers are identical except for the dictionary key they read, which is a good sign but also duplication. You can simplify and make future changes safer by factoring them into a single generic helper.

For example:

```diff
-def _brute_fparam_pt(data, ndim):
-    adata = [ii["fparam"] for ii in data]
-    all_data = []
-    for ii in adata:
-        tmp = np.reshape(ii, [-1, ndim])
-        if len(all_data) == 0:
-            all_data = np.array(tmp)
-        else:
-            all_data = np.concatenate((all_data, tmp), axis=0)
-    avg = np.average(all_data, axis=0)
-    std = np.std(all_data, axis=0)
-    return avg, std
-
-
-def _brute_aparam_pt(data, ndim):
-    adata = [ii["aparam"] for ii in data]
-    all_data = []
-    for ii in adata:
-        tmp = np.reshape(ii, [-1, ndim])
-        if len(all_data) == 0:
-            all_data = np.array(tmp)
-        else:
-            all_data = np.concatenate((all_data, tmp), axis=0)
-    avg = np.average(all_data, axis=0)
-    std = np.std(all_data, axis=0)
-    return avg, std
+def _brute_param_pt(data, key, ndim):
+    chunks = [np.reshape(d[key], [-1, ndim]) for d in data]
+    all_data = np.concatenate(chunks, axis=0)
+    avg = np.average(all_data, axis=0)
+    std = np.std(all_data, axis=0)
+    return avg, std
+
+
+def _brute_fparam_pt(data, ndim):
+    return _brute_param_pt(data, "fparam", ndim)
+
+
+def _brute_aparam_pt(data, ndim):
+    return _brute_param_pt(data, "aparam", ndim)
```

This keeps the intent clear and avoids having to update two almost-identical functions in the future.
69-95: Test logic correctly mirrors protection/clipping, but consider more robust numeric tolerance. The main test correctly mirrors `compute_input_stats`' behavior:

- You use population standard deviation (`np.std(..., ddof=0)` implicitly), like the implementation.
- The `protection` behavior is matched by computing `1.0 / std` and clipping inv-std at 100 when `protection=1e-2`, which is equivalent to clamping `std` to `>= protection` before inversion.
- You validate both fparam and aparam averages and inverse stddevs across multiple systems and shapes, which gives good coverage of the new path.

One potential improvement is numeric robustness: if `EnergyFittingNet` stores stats in a lower-precision dtype (e.g., float32), `np.testing.assert_almost_equal` with the default `decimal=7` can be a bit tight and lead to flaky failures. You could either relax the tolerance slightly or use explicit `rtol`/`atol` via `assert_allclose`:

```diff
-        np.testing.assert_almost_equal(frefa, fitting.fparam_avg)
-        np.testing.assert_almost_equal(frefs_inv, fitting.fparam_inv_std)
-        np.testing.assert_almost_equal(arefa, fitting.aparam_avg)
-        np.testing.assert_almost_equal(arefs_inv, fitting.aparam_inv_std)
+        np.testing.assert_allclose(frefa, fitting.fparam_avg, rtol=1e-5, atol=1e-7)
+        np.testing.assert_allclose(frefs_inv, fitting.fparam_inv_std, rtol=1e-5, atol=1e-7)
+        np.testing.assert_allclose(arefa, fitting.aparam_avg, rtol=1e-5, atol=1e-7)
+        np.testing.assert_allclose(arefs_inv, fitting.aparam_inv_std, rtol=1e-5, atol=1e-7)
```

This keeps the assertion strict enough to catch real regressions while making the test less sensitive to dtype or minor numerical differences.
📒 Files selected for processing (1)
- source/tests/common/dpmodel/test_fitting_stat.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
source/tests/common/dpmodel/test_fitting_stat.py (1)
deepmd/dpmodel/fitting/general_fitting.py (1)
compute_input_stats(225-288)
🔇 Additional comments (1)

source/tests/common/dpmodel/test_fitting_stat.py (1)

14-38: Data generator shapes line up correctly with `compute_input_stats`. The synthetic data generator is consistent with `GeneralFitting.compute_input_stats` expectations: `fparam` and `aparam` both have the fitting dimension as the last axis, and your upstream reshape to `[-1, ndim]` will work for both. Using per-parameter `default_rng` seeds makes the test deterministic and nicely exercises multiple systems/frames. No changes needed here from a correctness standpoint.
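A quick shape check of the layout described above (the concrete sizes are illustrative): `fparam` is per-frame and `aparam` per-atom, both with the parameter dimension last, so a single reshape to `[-1, ndim]` flattens either one for statistics.

```python
import numpy as np

nframes, natoms, ndof = 5, 10, 3
fparam = np.zeros((nframes, 1, ndof))       # one fparam vector per frame
aparam = np.zeros((nframes, natoms, ndof))  # one aparam vector per atom

flat_f = np.reshape(fparam, (-1, ndof))  # 5 rows: one per frame
flat_a = np.reshape(aparam, (-1, ndof))  # 50 rows: frames * atoms
assert flat_f.shape == (5, 3)
assert flat_a.shape == (50, 3)
```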
In the finetuning process, the computation of the fitting stat was skipped by the previous code. There are two situations:

- The model uses `fparam` or `aparam` with the same meaning as in the finetuning task. The keys `fparam_avg`/`fparam_inv_std`/`aparam_avg`/`aparam_inv_std` are loaded from the pretrained model. This is correct.
New Features
Tests